import torch
import torch.nn as nn
import pandas as pd
import MetaTrader5 as mt5

# Custom Transformer for tokenizing time
class TimeTokenizer(nn.Module):
    """Replace the leading time column with its position inside a 86400-unit cycle.

    Column 0 of the input is reduced modulo 86400 and divided by 86400; every
    other column is forwarded untouched. The column is presumably a timestamp
    in seconds, which would make the result a time-of-day fraction -- confirm
    against the caller.
    """

    def forward(self, X):
        """Return ``X`` with column 0 swapped for ``(X[:, 0] % 86400) / 86400``.

        Args:
            X: 2-D tensor whose first column holds the time values.

        Returns:
            A tensor with the same number of rows and columns as ``X``.
        """
        period = 86400
        raw_time = X[:, 0]
        remaining_columns = X[:, 1:]

        cycle_fraction = (raw_time % period) / period

        # Indexing with None restores the column axis so the pieces can be
        # joined side by side.
        return torch.cat([cycle_fraction[:, None], remaining_columns], dim=1)

# Custom Transformer for daily rolling normalization
class DailyRollingNormalizer(nn.Module):
    """Min-max scale each price column against its running range for the day.

    Column 0 is read as the time token and every remaining column as a price
    series. A running maximum and minimum are kept per column; both restart
    from the current row whenever the time token is smaller than in the
    previous row (i.e. the token wrapped around, which marks a new day).
    """

    def forward(self, X):
        """Return ``X`` with the price columns replaced by their scaled values.

        Args:
            X: 2-D tensor; column 0 is the time token, columns 1.. are prices.

        Returns:
            A tensor shaped like ``X``. Column 0 is passed through. Each price
            cell holds ``(price - running_min) / (running_max - running_min)``.
            Row 0, and any row where the running range is still zero (first
            bar of a day, or a flat stretch), is 0 rather than NaN.
        """
        time_tokens = X[:, 0]
        price_columns = X[:, 1:]

        # Row 0 has no history, so its normalized value stays at 0.
        normalized_price_columns = torch.zeros_like(price_columns)
        rolling_max = price_columns.clone()
        rolling_min = price_columns.clone()

        for i in range(1, price_columns.shape[0]):
            current = price_columns[i]
            # The token dropping below its predecessor means a new day began.
            new_day = time_tokens[i] < time_tokens[i - 1]

            # torch.where selects a branch outright, so a non-finite value in
            # the discarded branch cannot leak into the result (unlike
            # mask * a + (1 - mask) * b, where 0 * inf is NaN).
            rolling_max[i] = torch.where(
                new_day, current, torch.maximum(rolling_max[i - 1], current)
            )
            rolling_min[i] = torch.where(
                new_day, current, torch.minimum(rolling_min[i - 1], current)
            )

            price_range = rolling_max[i] - rolling_min[i]
            # A zero range means price == min, so the numerator is 0 too;
            # dividing by 1 there yields 0 instead of the 0/0 NaN.
            safe_range = torch.where(
                price_range == 0, torch.ones_like(price_range), price_range
            )
            normalized_price_columns[i] = (current - rolling_min[i]) / safe_range

        # Restore the column axis on the time token before re-joining.
        return torch.cat((time_tokens.unsqueeze(1), normalized_price_columns), dim=1)

class ReplaceNaNs(nn.Module):
    """Replace every NaN in the input with 0 and leave all other values alone."""

    def forward(self, X):
        """Return a new tensor equal to ``X`` except that NaN entries are 0.

        The input tensor is not modified. Infinite values are kept as they
        are; only NaN is replaced.

        Args:
            X: tensor of any shape.

        Returns:
            A tensor with the same shape and dtype as ``X``.
        """
        # torch.isnan matches every NaN regardless of sign bit, so a single
        # pass is sufficient.
        return torch.where(torch.isnan(X), torch.zeros_like(X), X)

# Connect to MetaTrader 5. Nothing below can work without a terminal
# connection, so stop here instead of failing later on missing data.
if not mt5.initialize():
    print("Initialize failed")
    mt5.shutdown()
    raise SystemExit(1)

# Load market data (reduced sample size for demonstration)
symbol = "EURUSD"
timeframe = mt5.TIMEFRAME_M15
try:
    # 160 bars keeps the demo small; raise it up to the maximum number of
    # bars allowed by your broker.
    rates = mt5.copy_rates_from_pos(symbol, timeframe, 0, 160)
    if rates is None or len(rates) == 0:
        raise SystemExit(f"No rates returned for {symbol}: {mt5.last_error()}")
finally:
    # Always release the terminal connection, even if the request failed.
    mt5.shutdown()

# Convert to DataFrame and keep only 'time', 'open', 'high', 'low', 'close' columns
data = pd.DataFrame(rates)[['time', 'open', 'high', 'low', 'close']]

# Convert the DataFrame to a PyTorch tensor.
# NOTE(review): float32 carries a 24-bit significand, so values in the
# 2**30..2**31 range are rounded to multiples of 128. If 'time' holds epoch
# seconds (presumably it does -- confirm), the time token derived from it is
# only accurate to about a minute. Kept as float32 because the exported model
# is traced with a float32 input.
data_tensor = torch.tensor(data.values, dtype=torch.float32)

# Create the updated pipeline
pipeline = nn.Sequential(
    TimeTokenizer(),
    DailyRollingNormalizer(),
    ReplaceNaNs()
)

# Print the data before processing
print('Data Before Processing\n', data[:100])

# Process the data
processed_data = pipeline(data_tensor)

print('Data After Processing\n', processed_data[:100])

# Export the pipeline to ONNX format. The example input only has to match the
# real data's shape (rows x columns); its values are random.
dummy_input = torch.randn(len(data), len(data.columns))
torch.onnx.export(pipeline, dummy_input, "data_processing_pipeline.onnx", input_names=["input"], output_names=["output"])
